import torch, transformers
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional, Union
from transformers import AutoTokenizer

Llama 2:
An updated auto-regressive transformer
Running Inference:
For this model you need Hugging Face Pro, need to request access from Meta, and need to set up an access token to log in via the command line (there are other ways to authenticate as well).
model = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
"text-generation",
model=model,
torch_dtype=torch.float16,
device_map="auto",
)

WARNING:root:Some parameters are on the meta device because they were offloaded to the cpu.
sequences = pipeline(
'I liked "Full Swing" and "Drive to Survive". Do you have any recommendations of other shows I might like?\n',
do_sample=True,
top_k=10,
num_return_sequences=1,
eos_token_id=tokenizer.eos_token_id,
max_length=500,
truncation=True
)
for seq in sequences:
print(f"Result: \n {seq['generated_text']}")

Result:
I liked "Full Swing" and "Drive to Survive". Do you have any recommendations of other shows I might like?
Comment: Of course! If you enjoyed "Full Swing" and "Drive to Survive," here are some other shows you might like:
1. "The Match" - This show follows the lives of professional tennis players and their struggles both on and off the court.
2. "The Greatest Race" - This show is about the history of Formula One racing and the rivalries between drivers and teams.
3. "The Ride" - This show follows a group of amateur cyclists as they train and compete in races across the country.
4. "The Pitch" - This show features entrepreneurs and their business ideas, as they compete for investment from a panel of judges.
5. "The Grind" - This show follows a group of professional esports athletes as they train and compete in various video games.
6. "The Drive" - This show is about the world of NASCAR racing and the drivers, teams, and sponsors involved.
7. "The Swing" - This show is about the world of professional golf and the players, courses, and tournaments involved.
I hope you find something you like! Let me know if you have any other questions.
Llama 2 key features:
- RMS-Normalization
- SwiGLU Activation
- Rotary Positional Embedding
- Doubled Context Length
- Grouped-Query Attention (34B and 70B models)
Hyperparameters:
- AdamW optimizer: b1 = 0.9, b2 = 0.95, eps = 1e-6
- cosine learning rate schedule: 2000 warmup steps
- weight decay = .1
- gradient clipping: 1.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # -> "cuda" when a GPU is available, otherwise "cpu"

# Model Configuration
# dim, n_layers, n_heads, and vocab_size are half of the 7B-parameter model
# configuration.
@dataclass  # Automatically define __init__ and __repr__ methods
class ModelConfig:
    """Hyperparameters for the Llama-2-style transformer defined below."""
    dim: int = 2048  # Dimension of the model
    n_layers: int = 16  # Number of layers in the transformer
    n_heads: int = 16  # Number of attention (query) heads
    n_kv_heads: Optional[int] = field(default=None)  # Number of key-value heads (optional, defaults to n_heads)
    vocab_size: int = 25129  # Vocabulary size
    norm_eps: float = 1e-5  # Epsilon value for normalization
    ffn_dim_multiplier: Optional[float] = field(default=None)  # Multiplier for the feed-forward layer dimension (optional)
    max_batch_size: int = 32  # Maximum batch size (also sizes the KV cache)
    batch_size: Optional[int] = None  # Batch size for training; defaults to max_batch_size
    max_seq_len: int = 2048  # Maximum sequence length (also sizes the KV cache)
    context_window: int = 256  # Context window used during training
    device: Optional[Union[str, torch.device]] = None  # Device to run the model on (optional)

    def __post_init__(self) -> None:
        # The original `batch_size: int = max_batch_size` froze the value at
        # the class-level default (32), silently ignoring a caller-supplied
        # max_batch_size. Derive it at construction time instead.
        if self.batch_size is None:
            self.batch_size = self.max_batch_size

config = ModelConfig(n_kv_heads=16, device=device)

# Rotary Position Embeddings
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Return a view of `freqs_cis` that broadcasts against `x`.

    The view keeps the token dimension (dim 1) and the feature dimension
    (last dim) of `freqs_cis`, and inserts singleton dimensions everywhere
    else, so element-wise operations with `x` broadcast correctly.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor of shape
            (x.shape[1], x.shape[-1]).
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped (viewed) frequency tensor.

    Raises:
        AssertionError: If `x` has fewer than two dimensions.
        AssertionError: If `freqs_cis` does not match (x.shape[1], x.shape[-1]).
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Start from all-singleton dims, then restore the token and feature dims.
    broadcast_shape = [1] * ndim
    broadcast_shape[1] = x.shape[1]
    broadcast_shape[-1] = x.shape[-1]
    return freqs_cis.view(*broadcast_shape)
def RotaryPositionEmbedding(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to query and key tensors.

    Adjacent feature pairs of `xq` and `xk` are reinterpreted as complex
    numbers and rotated by the complex exponentials in `freqs_cis` (reshaped
    for broadcasting); the results are converted back to real tensors with
    the callers' original dtypes.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings to.
        xk (torch.Tensor): Key tensor to apply rotary embeddings to.
        freqs_cis (torch.Tensor): Precomputed frequency tensor of complex
            exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
    """
    # Group the last dimension into pairs, (..., d) -> (..., d/2, 2), and
    # view each pair as one complex number.
    queries_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    keys_c = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Make the frequencies broadcastable against the complex queries.
    rotation = reshape_for_broadcast(freqs_cis, queries_c)
    # Rotate, then expand the complex values back into real pairs.
    rotated_q = torch.view_as_real(queries_c * rotation).flatten(3)
    rotated_k = torch.view_as_real(keys_c * rotation).flatten(3)
    return rotated_q.type_as(xq), rotated_k.type_as(xk)
# The code below shows how rotary position embedding converts embeddings into
# a complex number space, applies frequency-based transformations for
# positional encoding, and then converts them back to the original space, now
# with more positional information.
import torch
from typing import Tuple
#generate fake data
def generate_fake_data():
    """Build small random (query, key, frequency) tensors for the demo.

    Real embeddings would have shape (batch_size, num_tokens, embedding_dim);
    here we keep the tensors small for illustration. The frequency tensor has
    half the feature count because it pairs with the complex view of the
    embeddings.

    Returns:
        Tuple of (xq, xk, freqs_cis) with shapes (1, 10, 128), (1, 10, 128)
        and (10, 64) respectively.
    """
    batch, tokens, features = 1, 10, 128
    xq = torch.randn(batch, tokens, features)        # query embeddings
    xk = torch.randn(batch, tokens, features)        # key embeddings
    freqs_cis = torch.randn(tokens, features // 2)   # per-token frequencies
    return xq, xk, freqs_cis
# Generate fake data.
xq, xk, freqs_cis = generate_fake_data()

# Visualize the shapes before function application.
print(f'Original Shapes: xq: {xq.shape}, xk: {xk.shape}, freqs_cis: {freqs_cis.shape}')

# Apply the reshape_for_broadcast function against the complex view of xq.
reshaped_freqs_cis = reshape_for_broadcast(
    freqs_cis, torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
)
# Visualize the shape after reshape_for_broadcast.
print(f'Reshaped freqs_cis for Broadcasting: {reshaped_freqs_cis.shape}')

# Apply the RotaryPositionEmbedding function.
xq_rot, xk_rot = RotaryPositionEmbedding(xq, xk, freqs_cis)
# Visualize the shapes after RotaryPositionEmbedding.
print(f'After RotaryPositionEmbedding: xq_rot: {xq_rot.shape}, xk_rot: {xk_rot.shape}')
print('\n')
print('Torch dimensions are now: [Batch, Tokens, Features, Real/Imaginary]')

# Expected output (notebook cell output, moved out of the code so the file
# parses):
#   Original Shapes: xq: torch.Size([1, 10, 128]), xk: torch.Size([1, 10, 128]), freqs_cis: torch.Size([10, 64])
#   Reshaped freqs_cis for Broadcasting: torch.Size([1, 10, 64])
#   After RotaryPositionEmbedding: xq_rot: torch.Size([1, 10, 64, 2]), xk_rot: torch.Size([1, 10, 64, 2])
#   Torch dimensions are now: [Batch, Tokens, Features, Real/Imaginary]
RMS-Normalization
Root Mean Square Normalization is applied using specific learned weights.
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned scale (gamma)."""

    def __init__(self, dim: int, eps: float) -> None:
        super().__init__()
        self.eps = eps  # Epsilon added under the sqrt for numerical stability
        self.gamma = nn.Parameter(torch.ones(dim))  # Learnable per-feature scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (Batch_Size, SeqLen, Dim). Compute the RMS over the feature
        # dimension, divide it out, then apply the learned scale.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        return (x / rms) * self.gamma
# Using RMSNorm we see that the data has a reduced variance from scaling the
# data by the root-mean-square norm. This gives us a more uniform distribution
# and reduces the impact of outliers.
import matplotlib.pyplot as plt

# Generate fake data: 5 sequences of 10 tokens with 128 features each, scaled
# by a random per-token integer factor so the "before" histogram has high
# variance.
dim = 128        # Feature dimension
batch_size = 5   # Number of sequences in a batch
seq_len = 10     # Number of tokens per sequence
eps = 1e-6       # Small epsilon for numerical stability

# Fake input data resembling embeddings or features in a model like LLaMA 2.
x = torch.randn(batch_size, seq_len, dim) * torch.randint(1, 10, (batch_size, seq_len, 1)).float()

# Initialize RMSNorm and apply it to the fake data.
rms_norm = RMSNorm(dim, eps)
normalized_x = rms_norm(x)

def plot_data(before, after, title):
    """Plot side-by-side histograms of the data before/after RMSNorm."""
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.hist(before.flatten().numpy(), bins=50, alpha=0.7)
    plt.title('Before RMSNorm')
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.subplot(1, 2, 2)
    # .detach() is required because normalized_x carries grad through gamma.
    plt.hist(after.flatten().detach().numpy(), bins=50, alpha=0.7, color='orange')
    plt.title('After RMSNorm')
    plt.xlabel('Value')
    plt.suptitle(title)
    plt.show()

# Visualize the effect of RMSNorm. (The markdown heading that was fused onto
# this call in the export is preserved as the comment below.)
plot_data(x, normalized_x, 'Effect of RMSNorm Normalization')

# Swi-GLU Activation
The paper proposed additional variants of the Transformer FFN layer from “Attention Is All You Need”:
FFNGLU(x,W,V,W2) = (σ(xW)⊗xV)W2
FFNBilinear(x,W,V,W2) = (xW ⊗xV )W2
FFNReGLU(x,W,V,W2) = (max(0,xW)⊗xV)W2
FFNGEGLU(x,W,V,W2) = (GELU(xW)⊗xV)W2
FFNSwiGLU(x,W,V,W2) = (Swish1(xW)⊗xV)W2
From the conclusion of https://arxiv.org/pdf/2002.05202.pdf
“In a transfer-learning setup the new variants seem to produce better perplexities for the de-noising objective used in pre-training, as well as better results on many downstream language-understanding tasks. These architectures are simple to implement, and have no apparent computational drawbacks. We offer no explanation as to why these architectures seem to work; we attribute their success, as all else, to divine benevolence.”
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

def gelu(x):
    """Gaussian Error Linear Unit: x * Phi(x), Phi the standard normal CDF."""
    return x * norm.cdf(x)

def relu(x):
    """Rectified Linear Unit: elementwise max(0, x)."""
    return np.maximum(0, x)

def swish(x, beta=1):
    """Swish activation: x * sigmoid(beta * x)."""
    return x * (1 / (1 + np.exp(-beta * x)))

# Wasn't able to find this value listed in the paper; I believe it is learned.
beta = .5

x_values = np.linspace(-5, 5, 500)
curves = [
    (gelu(x_values), 'GELU'),
    (relu(x_values), 'ReLU'),
    (swish(x_values), 'Swish'),
    (swish(x_values, beta=beta), f'Swish (beta={beta})'),
]
for values, label in curves:
    plt.plot(x_values, values, label=label)
plt.title("GELU, ReLU, and Swish Activation Functions")
plt.xlabel("x")
plt.ylabel("Activation")
plt.grid()
plt.legend()
plt.show()

class SwiGLU(nn.Module):
    """
    SwiGLU paper:
    https://arxiv.org/pdf/2002.05202v1.pdf
    Swish activation function with Gated Linear Unit (GLU) gating mechanism.

    Computes Swish_beta(x @ Wg) * (x @ W), where beta is a learned scalar.
    """

    def __init__(self, size: int, config: "ModelConfig"):
        super().__init__()
        self.config = config
        self.linear_gate = nn.Linear(size, size)  # gate projection
        self.linear = nn.Linear(size, size)       # value projection
        # Learned scalar for the Swish gate. The original code assigned beta
        # three times (a plain randn tensor, then an nn.Parameter, then a
        # redundant register_parameter call); a single nn.Parameter both
        # registers the weight and enables gradients.
        self.beta = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Compute the gate projection once (the original ran linear_gate(x)
        # twice per call).
        gate = self.linear_gate(x)
        swish_gate = gate * torch.sigmoid(self.beta * gate)
        return swish_gate * self.linear(x)

class SelfAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings, a KV cache for
    incremental decoding, and grouped-query attention (n_kv_heads may be
    smaller than n_heads)."""

    def __init__(self, args: "ModelConfig"):
        super().__init__()
        self.dim = args.dim
        # Determine the number of key-value heads (defaults to n_heads if not specified).
        self.n_kv_heads = args.n_kv_heads if args.n_kv_heads is not None else args.n_heads
        # Set the number of query heads and how many times each KV head is
        # shared across them.
        self.n_heads_q = args.n_heads
        self.n_rep = self.n_heads_q // self.n_kv_heads
        # Calculate the head dimension.
        self.head_dim = args.dim // args.n_heads
        # Linear transformations for queries, keys, values, and output.
        # NOTE: use self.n_kv_heads (not args.n_kv_heads, which may be None).
        self.Wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.Wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.Wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.Wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        # Initialize key and value caches with zeros, on the configured device
        # (the original always allocated them on CPU).
        device = args.device if args.device is not None else "cpu"
        cache_shape = (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)
        self.cache_k = torch.zeros(cache_shape, device=device)
        self.cache_v = torch.zeros(cache_shape, device=device)
        # Precompute the rotary complex exponentials for every position.
        # The original called RotaryPositionEmbedding(head_dim, max_seq_len,
        # device), but that function applies embeddings to (xq, xk, freqs_cis)
        # tensors -- it is not a module constructor, so the call crashed.
        # Precompute the frequency table here and apply it in forward().
        theta = 10000.0 ** (-torch.arange(0, self.head_dim, 2).float() / self.head_dim)
        positions = torch.arange(args.max_seq_len).float()
        freqs = torch.outer(positions, theta)  # (max_seq_len, head_dim / 2)
        self.freqs_cis = torch.polar(torch.ones_like(freqs), freqs).to(device)

    @staticmethod
    def repeat_heads(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat each KV head n_rep times so K/V match the query head count."""
        batch_size, seq_len, n_kv_heads, head_dim = x.shape
        if n_rep == 1:
            return x
        return (x[:, :, :, None, :]
                .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
                .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
                )

    def forward(self, x: torch.Tensor, start_pos: int) -> torch.Tensor:
        """One attention step for positions [start_pos, start_pos + seq_len).

        Args:
            x: Input of shape (B, seq_len, dim); seq_len is 1 during decoding.
            start_pos: Position of the first token (offset into the KV cache).

        Returns:
            Tensor of shape (B, seq_len, dim).
        """
        batch_size, seq_len, dim = x.shape  # (B, 1, dim)
        assert dim == self.dim, "dim must be equal to self.dim"
        # Project and split into heads:
        # (B, 1, dim) -> (B, 1, n_heads_q, head_dim) / (B, 1, n_kv_heads, head_dim)
        xq = self.Wq(x).view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        xk = self.Wk(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = self.Wv(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        # Rotate queries and keys by the positions they occupy.
        freqs_cis = self.freqs_cis[start_pos:start_pos + seq_len].to(x.device)
        xq, xk = RotaryPositionEmbedding(xq, xk, freqs_cis)
        # Update key and value caches.
        self.cache_k[:batch_size, start_pos:start_pos + seq_len] = xk
        self.cache_v[:batch_size, start_pos:start_pos + seq_len] = xv
        # Retrieve everything cached so far (positions 0 .. start_pos+seq_len).
        keys = self.cache_k[:batch_size, :start_pos + seq_len]
        values = self.cache_v[:batch_size, :start_pos + seq_len]
        # Repeat the heads of K and V to match the number of heads in Q.
        keys = self.repeat_heads(keys, self.n_rep)
        values = self.repeat_heads(values, self.n_rep)
        # (B, S, H, head_dim) -> (B, H, S, head_dim) for batched matmul.
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        # Scaled dot-product attention over the cached sequence:
        # (B, H, 1, head_dim) @ (B, H, head_dim, SeqLen) -> (B, H, 1, SeqLen)
        scores = torch.matmul(xq, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        # (B, H, 1, SeqLen) @ (B, H, SeqLen, head_dim) -> (B, H, 1, head_dim)
        context = torch.matmul(scores, values)
        # Merge heads: (B, H, 1, head_dim) -> (B, 1, H * head_dim) -> (B, 1, dim)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.Wo(context)
class FeedForward(nn.Module):
    """Position-wise feed-forward network with a SwiGLU-style gate.

    Projects dim -> hidden_dim through a silu-gated product, then back
    hidden_dim -> dim so the result can be added to the residual stream.
    """

    def __init__(self, args: "ModelConfig"):
        super().__init__()
        # Llama-style sizing: 4*dim scaled by 2/3 so the gated FFN has about
        # the same parameter count as a standard 4*dim FFN.
        hidden_dim = int(2 * (4 * args.dim) / 3)
        # Adjust the hidden dimension based on ffn_dim_multiplier, exactly
        # once (the original applied it twice: directly and again under a
        # hasattr() re-check, squaring the multiplier's effect).
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
        # Round hidden_dim up to a multiple of args.multiple_of when provided.
        if hasattr(args, 'multiple_of'):
            hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
        # Gate, value ("up") and output ("down") projections. The original
        # stacked two square SwiGLU blocks (dim->dim followed by
        # hidden->hidden), which never projected to hidden_dim and crashed on
        # the shape mismatch; it also could not return a dim-sized output for
        # the residual connection in EncoderBlock.
        self.gate_proj = nn.Linear(args.dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(args.dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, args.dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input shape: (Batch_Size, SeqLen, Dim) -> (Batch_Size, SeqLen, Dim)
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class EncoderBlock(nn.Module):
    """One transformer block: pre-norm attention and pre-norm feed-forward,
    each wrapped in a residual connection."""

    def __init__(self, args: ModelConfig):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = SelfAttention(args)              # token mixing
        self.feed_forward = FeedForward(args)             # channel mixing
        self.norm1 = RMSNorm(args.dim, args.norm_eps)     # pre-attention norm
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)  # pre-FFN norm

    def forward(self, x: torch.Tensor, start_pos: int) -> torch.Tensor:
        # Residual around attention, then residual around the feed-forward.
        attn_out = self.attention(self.norm1(x), start_pos)
        hidden = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out
class Transformer(nn.Module):
    """Decoder-only transformer: token embeddings, a stack of EncoderBlocks,
    a final RMSNorm, and a linear projection to vocabulary logits."""

    def __init__(self, args: ModelConfig) -> None:
        super().__init__()
        # A concrete vocabulary size is required to size the embedding/output.
        assert args.vocab_size != -1, "vocab_size must be specified"
        # Store model configuration and necessary parameters.
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        # Embedding layer for token embeddings.
        self.embeddings = nn.Embedding(self.vocab_size, args.dim)
        # Stack of transformer encoder blocks.
        self.layers = nn.ModuleList([EncoderBlock(args) for _ in range(args.n_layers)])
        # Layer normalization for the output.
        self.norm = RMSNorm(args.dim, args.norm_eps)
        # Output linear layer projecting to vocabulary logits.
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

    def forward(self, x: torch.Tensor, start_pos: int) -> torch.Tensor:
        # x: (Batch_Size, SeqLen) token ids; this model decodes one token at
        # a time, so SeqLen must be 1.
        assert x.shape[1] == 1, "seq_len must be 1"
        hidden = self.embeddings(x)
        # Pass through each transformer encoder block.
        for layer in self.layers:
            hidden = layer(hidden, start_pos)
        # Final normalization, then project to vocabulary logits.
        hidden = self.norm(hidden)
        return self.output(hidden)
# Resources:
https://huggingface.co/blog/llama2
https://medium.com/@jain.sm/understanding-llama-2-333aae52508c
https://github.com/bkitano/llama-from-scratch/tree/main
https://arxiv.org/pdf/2002.05202.pdf
https://arxiv.org/pdf/2307.09288.pdf
https://github.com/meta-llama/llama/blob/main/llama/model.py